﻿#include "ColliderDefinitions.cginc"
#include "ContactHandling.cginc"
#include "DistanceFunctions.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "Optimization.cginc"

#pragma kernel GenerateContacts

// Particle data. In this kernel principalRadii and velocities are indexed by particle index;
// positions and orientations are only forwarded to Optimize().
StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> principalRadii; // only the .x component is read by this kernel.
StructuredBuffer<float4> velocities;

// Particle indices of all simplices: entries [simplexStart, simplexStart + simplexSize) belong to one simplex.
StructuredBuffer<int> simplices;

// Collider data, both indexed by collider index.
StructuredBuffer<transform> transforms;
StructuredBuffer<shape> shapes;

// Distance field data: not read directly by this kernel, only handed over to the
// DistanceField struct that is passed to Optimize().
StructuredBuffer<DistanceFieldHeader> distanceFieldHeaders;
StructuredBuffer<DFNode> dfNodes;

// Candidate pairs: .x is the simplex index, .y is the collider index.
StructuredBuffer<uint2> contactPairs;
// Index of the first entry in contactPairs for a given shape type (read here at SDF_SHAPE).
StructuredBuffer<int> contactOffsetsPerType;

// Output contacts. Slots are allocated using the buffer's hidden counter (IncrementCounter).
RWStructuredBuffer<contact> contacts;
// Entry 11 + 4 * shapeType is read as the amount of pairs for that shape type.
// Entries 0 and 3 are raised with InterlockedMax as contacts get written
// (presumably thread group count and contact count for a later dispatch -- confirm against the host code).
RWStructuredBuffer<uint> dispatchBuffer;

// Only element 0 is used: it is multiplied with the collider transform to get the collider-to-solver transform.
StructuredBuffer<transform> worldToSolver;

// Contacts whose allocated slot index is >= maxContacts are discarded.
uint maxContacts;
// NOTE(review): only referenced by the commented-out speculative test in GenerateContacts.
float deltaTime;

// Contact generation for colliders of type SDF_SHAPE.
// Every thread handles one (simplex, collider) pair taken from contactPairs: it builds a
// DistanceField for the collider, runs Optimize() to obtain a point on the collider surface,
// and appends one contact to the contacts buffer (as long as there is room for it).
[numthreads(128, 1, 1)]
void GenerateContacts (uint3 id : SV_DispatchThreadID)
{
    uint pairIndex = id.x;

    // entry #11 in the dispatch buffer is the amount of pairs for the first shape type,
    // each following shape type sits 4 entries further.
    uint pairCount = dispatchBuffer[11 + 4 * SDF_SHAPE];
    if (pairIndex >= pairCount)
        return;

    // fetch the pair handled by this thread: x = simplex, y = collider.
    uint2 pairIds = contactPairs[contactOffsetsPerType[SDF_SHAPE] + pairIndex];
    int simplexIndex = pairIds.x;
    int colliderIndex = pairIds.y;

    // colliders without distance field data (negative data index) generate no contacts.
    shape colliderShape = shapes[colliderIndex];
    if (colliderShape.dataIndex < 0)
        return;

    // set up the distance field shape, expressed in solver space.
    DistanceField dfShape;
    dfShape.s = colliderShape;
    dfShape.dfNodes = dfNodes;
    dfShape.distanceFieldHeaders = distanceFieldHeaders;
    dfShape.colliderToSolver = worldToSolver[0].Multiply(transforms[colliderIndex]);

    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(simplexIndex, simplexSize);

    // start the search from the barycenter of the simplex.
    float4 simplexPoint;
    float4 simplexBary = BarycenterForSimplexOfSize(simplexSize);

    SurfacePoint colliderPoint = Optimize(dfShape, positions, orientations, principalRadii,
                                          simplices, simplexStart, simplexSize, simplexBary, simplexPoint,
                                          surfaceCollisionIterations, surfaceCollisionTolerance);

    // barycentric interpolation of particle radius and velocity over the simplex.
    // These two values only feed the speculative contact test below, which is currently disabled.
    float simplexRadius = 0;
    float4 velocity = FLOAT4_ZERO;
    for (int k = 0; k < simplexSize; ++k)
    {
        int particleIndex = simplices[simplexStart + k];
        float weight = simplexBary[k];

        simplexRadius += principalRadii[particleIndex].x * weight;
        velocity += velocities[particleIndex] * weight;
    }

    // Disabled speculative contact test, kept for reference:
    //
    // float4 rbVelocity = float4.zero;
    // if (rigidbodyIndex >= 0)
    //     rbVelocity = BurstMath.GetRigidbodyVelocityAtPoint(rigidbodyIndex, colliderPoint.point, rigidbodies, solverToWorld);
    //
    // float dAB = dot(simplexPoint - colliderPoint.pos, colliderPoint.normal);
    // float vel = dot(velocity /*- rbVelocity*/, colliderPoint.normal);
    //
    // if (vel * deltaTime + dAB <= simplexRadius + s.contactOffset + collisionMargin)

    // allocate a slot in the contacts buffer; slots past the capacity are dropped.
    uint slot = contacts.IncrementCounter();
    if (slot >= maxContacts)
        return;

    contact c = (contact)0;

    c.bodyA = simplexIndex;
    c.bodyB = colliderIndex;
    c.pointA = simplexBary;
    c.pointB = colliderPoint.pos;
    c.normal = colliderPoint.normal * dfShape.s.isInverted();

    contacts[slot] = c;

    // keep the dispatch buffer up to date with the amount of contacts written so far:
    // entry 0 receives contactCount / 128 + 1, entry 3 receives contactCount.
    uint contactCount = slot + 1;
    InterlockedMax(dispatchBuffer[0], contactCount / 128 + 1);
    InterlockedMax(dispatchBuffer[3], contactCount);
}